package top.lingkang.mm.page;

import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import top.lingkang.mm.error.MagicException;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;

/**
 * Query pagination interceptor.
 *
 * <p>Intercepts {@code StatementHandler.prepare(Connection, Integer)}. When a {@link PageInfo} is
 * returned by {@link PageHelper#getPage()} and it is not yet marked complete, this interceptor:
 * <ol>
 *   <li>asks the connection's {@link PageSqlHandle} for a count SQL and a paged select SQL,</li>
 *   <li>runs the count SQL (with the original statement's parameters bound) and stores the
 *       first column of the first row as the page total, marking the page complete,</li>
 *   <li>if the total is positive, rewrites the bound SQL to the paged select SQL,</li>
 * </ol>
 * and then lets the original {@code prepare} call proceed.
 *
 * @author lingkang
 * Created by 2024/3/3
 */
@Slf4j
@Intercepts({@Signature(
        type = StatementHandler.class,
        method = "prepare",
        args = {Connection.class, Integer.class})})
public class MagicPageInterceptor implements Interceptor {

    /** Keyword a statement must start with (case-insensitively) to be paginated. */
    private static final String SELECT_KEYWORD = "select";

    /** Reflective handle on BoundSql's private {@code sql} field, used to swap in the paged SQL. */
    private final Field sqlField;

    /**
     * Resolves and opens up the {@code BoundSql.sql} field once, up front.
     *
     * @throws MagicException if the field cannot be found or made accessible
     */
    public MagicPageInterceptor() {
        try {
            sqlField = BoundSql.class.getDeclaredField("sql");
            sqlField.setAccessible(true);
        } catch (Exception e) {
            throw new MagicException(e);
        }
    }

    /**
     * Performs the count query and SQL rewrite described on the class, then proceeds.
     *
     * @param invocation the intercepted {@code prepare} call; argument 0 is the JDBC connection
     * @return whatever the original {@code prepare} returns
     * @throws Throwable anything thrown by the count query, the reflective rewrite, or the
     *                   original invocation
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        PageInfo page = PageHelper.getPage();
        if (page == null || page.isComplete()) {
            return invocation.proceed();
        }

        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql();
        if (!isSelect(sql)) {
            log.warn("启用分页查询后，执行的sql不是查询sql: {}", sql);
            return invocation.proceed();
        }

        Connection conn = (Connection) invocation.getArgs()[0];
        PageSqlHandle pageSqlHandle = PageHelper.getHandle(conn);
        PageSqlInfo sqlInfo = pageSqlHandle.handleSql(sql, page.getPage(), page.getSize());

        // Run the count query on the same connection. Both JDBC handles are closed on every path
        // (row / no row / exception); the connection itself is left open for the real query.
        try (PreparedStatement statement = conn.prepareStatement(sqlInfo.getCountSql())) {
            // Bind the original statement's parameters onto the count query.
            pageSqlHandle.handleParams(boundSql, statement);
            try (ResultSet resultSet = statement.executeQuery()) {
                if (resultSet.next()) {
                    page.setTotal(resultSet.getLong(1));
                    page.setComplete(true);
                }
            }
        }

        if (page.getTotal() > 0) {
            // There are rows: replace the SQL that the proceeding prepare() will use with the
            // paged variant.
            sqlField.set(boundSql, sqlInfo.getSelectSql());
        }
        return invocation.proceed();
    }

    /**
     * Returns whether {@code sql}, ignoring leading whitespace and letter case, starts with
     * {@code select}. Leading whitespace is common in SQL built from mapper text; the
     * case-insensitive region match avoids locale-dependent lower-casing.
     *
     * @param sql the statement text, may be {@code null}
     * @return {@code true} for a non-null statement beginning with {@code select}
     */
    private static boolean isSelect(String sql) {
        return sql != null
                && sql.trim().regionMatches(true, 0, SELECT_KEYWORD, 0, SELECT_KEYWORD.length());
    }

}
